import ei
import os, sys
import json
import warnings


def get_icon_dataset(image_size, random_color=True, train_ratio=0.9):
    """Build the IconContourDataset splits used for training and evaluation.

    Args:
        image_size: passed through to every ``IconContourDataset``.
        random_color: passed to the augmented training split only.
        train_ratio: end of the augmented training split's ``split`` range.
            The boundary between the plain 'train' and 'test' splits is
            ``min(train_ratio, 0.9)``.

    Returns:
        Dict with keys 'train' and 'test', plus 'aug_train' (built with
        random crop, random transpose and optional random color) only when
        ``train_ratio > 0``.

    NOTE(review): assuming ``split`` is a fractional (start, end) range, a
    ``train_ratio`` above 0.9 makes 'aug_train' = (0, train_ratio) extend
    past the 0.9 train/test boundary and overlap 'test' — confirm intended.
    """
    from iconflow.dataset import IconContourDataset

    root = 'datasets/icon4/data/in_memory'
    boundary = min(train_ratio, 0.9)

    splits = {}
    splits['train'] = IconContourDataset(root, image_size, split=(0, boundary))
    splits['test'] = IconContourDataset(root, image_size, split=(boundary, 1.0))

    # No augmented training split without a positive training ratio.
    if not (train_ratio > 0.0):
        return splits

    augmentation = dict(
        random_crop=True,
        random_transpose=True,
        random_color=random_color,
    )
    splits['aug_train'] = IconContourDataset(
        root, image_size, split=(0, train_ratio), **augmentation
    )
    return splits

def get_net():
    """Return the network to train.

    Placeholder: returns the string 'Net' rather than a module, so any
    caller that uses the result as a torch module (e.g. ``.parameters()``)
    fails until this is implemented.
    """
    # Imported for the eventual model definition; currently unused.
    import torch.nn as nn

    placeholder = 'Net'  # TODO: replace with a real nn.Module
    return placeholder


def get_opt(net):
    """Return the optimizer for ``net``.

    Placeholder: ``net`` is ignored and the string 'Opt' is returned.
    """
    placeholder = 'Opt'  # TODO: replace with a real optimizer
    return placeholder

def train(
    image_size=256,

    device='cpu',
    batch_size=64,
    num_workers=4,
    output_dir='output/temp/test',
    train_ratio=0.9,

    log_int=100,
    sample_int=1000,
    save_int=1000,
    save_iters=(),
    end_iter=300000,

    random_color=False,
):
    """Train the icon model, logging to TensorBoard and checkpointing to ``output_dir``.

    Driven from the command line through ``fire``, so every keyword below is
    also a CLI flag.

    Args:
        image_size: input resolution; only 256 is accepted.
        device: torch device string, e.g. 'cpu' or 'cuda:0'.
        batch_size: training batch size.
        num_workers: DataLoader worker processes.
        output_dir: directory for TensorBoard events, ``config.json`` and
            checkpoints. An existing ``checkpoint.pt`` here is resumed from.
        train_ratio: forwarded to ``get_icon_dataset``; must be > 0 so that
            an augmented training split exists.
        log_int: write scalar logs every ``log_int`` iterations.
        sample_int: write a sample grid every ``sample_int`` iterations from
            iteration 1000 on (every 200 iterations before that).
        save_int: overwrite ``checkpoint.pt`` every ``save_int`` iterations.
        save_iters: completed-step counts N at which an additional
            ``checkpoint_N.pt`` is kept.
        end_iter: stop after this many completed steps; a non-int value
            (e.g. None) trains until interrupted.
        random_color: forwarded to the augmented training split.

    Raises:
        ValueError: if ``train_ratio`` yields no augmented training split.

    Note:
        ``get_net``, the loss in ``step`` and the pairwise outputs in
        ``sample`` are still placeholders; this fails at runtime until they
        are implemented.
    """
    assert image_size == 256

    ei.patch()
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    # Snapshot the arguments before any other local exists, so config.json
    # records exactly the call/CLI parameters (plus argv below).
    config = locals().copy()
    config['argv'] = sys.argv.copy()
    warnings.filterwarnings("ignore")

    import random
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard.writer import SummaryWriter
    from torchvision.utils import make_grid
    from tqdm import tqdm

    from iconflow.utils import cycle_iter, random_sampler

    os.makedirs(output_dir, exist_ok=True)
    device = torch.device(device)

    # ---- Dataset ----

    d = get_icon_dataset(
        image_size,
        random_color=random_color,
        train_ratio=train_ratio,
    )
    # get_icon_dataset only builds 'aug_train' when train_ratio > 0.
    if 'aug_train' not in d:
        raise ValueError(
            f'train_ratio must be > 0 to build a training split, got {train_ratio!r}'
        )

    train_loader = DataLoader(
        d['aug_train'],
        batch_size=batch_size,
        sampler=random_sampler(len(d['aug_train'])),
        pin_memory=(device.type == 'cuda'),
        num_workers=num_workers
    )

    # ---- Model ----

    net = get_net()
    # Move the parameters before building the optimizer, as torch.optim
    # recommends, so the optimizer holds the on-device parameters.
    net.to(device)
    opt = optim.Adam(net.parameters(), lr=1e-4)

    # Resume only when a checkpoint exists. A checkpoint that exists but
    # fails to load raises instead of silently restarting from iteration 0
    # (which would then overwrite it).
    ckpt_path = os.path.join(output_dir, 'checkpoint.pt')
    if os.path.isfile(ckpt_path):
        state = torch.load(ckpt_path, map_location=device)
        net.load_state_dict(state['net'])
        opt.load_state_dict(state['opt'])
        it = state['it']
        print(f'loaded from checkpoint, it: {it}')
        del state
    else:
        it = 0

    # ---- Save config ----

    print(config)

    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    # ---- Sample ----

    @torch.no_grad()
    def sample():
        """Build an image grid of a few train/test examples for TensorBoard.

        Examples are drawn with an RNG seeded from the current iteration.
        Each grid row shows one example followed by its pairwise outputs.
        """
        net.eval()

        # Fewer examples at larger resolutions.
        if image_size > 256:
            n_train, n_test = 1, 1
        elif image_size > 128:
            n_train, n_test = 2, 2
        else:
            n_train, n_test = 4, 4

        rng = random.Random(it % 300000 + 1337)
        # Assumes each dataset item is an (image, contour) pair - TODO confirm.
        x1, c1 = zip(*rng.choices(d['train'], k=n_train),
                     *rng.choices(d['test'], k=n_test))
        x1, c1 = map(torch.stack, (x1, c1))
        x1 = x1.to(device)
        c1 = c1.to(device)

        r12_list = []

        for i in range(len(x1)):
            # Pair every example with the batch rolled by i.
            x2 = x1.roll(i, 0)
            # TODO: placeholder; torch.stack below needs a tensor shaped like x1.
            r12 = f'TODO{x2}'
            r12_list.append(r12)

        rows = torch.stack([x1, c1.expand_as(x1), *r12_list])
        images = rows.permute(1, 0, 2, 3, 4).reshape(-1, *rows.shape[-3:])

        return make_grid(images, nrow=len(rows))

    # ---- Save ----

    def save(add_postfix=False, step_count=None):
        """Write net/optimizer state, the iteration counter and config.

        Args:
            add_postfix: write ``checkpoint_<step_count>.pt`` instead of
                overwriting ``checkpoint.pt``.
            step_count: iteration stored as ``'it'``, i.e. the iteration to
                resume from on reload; defaults to the current ``it``.
        """
        if step_count is None:
            step_count = it
        file_name = f'checkpoint_{step_count}.pt' if add_postfix else 'checkpoint.pt'

        torch.save({
            'net': net.state_dict(),
            'opt': opt.state_dict(),
            'it': step_count,
            'config': config,
        }, os.path.join(output_dir, file_name))

    # ---- Train ----

    def step(X: torch.Tensor, C: torch.Tensor):
        """Run one optimisation step on a batch; return scalars to log."""
        x1 = X.to(device)
        c1 = C.to(device)

        log_dict = {}

        net.train()

        # TODO: placeholder - the int 0 has no .backward(); replace with the
        # real loss tensor.
        loss = 0

        opt.zero_grad()
        loss.backward()
        opt.step()
        opt.zero_grad()

        return log_dict

    # A tqdm total is only meaningful for a finite integer end_iter.
    total = max(end_iter - it, 0) if isinstance(end_iter, int) else None
    writer = SummaryWriter(output_dir)
    try:
        try:
            with tqdm(cycle_iter(train_loader), total=total) as iter_loader:
                for X, C in iter_loader:
                    if isinstance(end_iter, int) and it >= end_iter:
                        break

                    log_dict = step(X, C)

                    iter_loader.set_postfix({'it': it, **log_dict})

                    if it % log_int == 0:
                        for key, value in log_dict.items():
                            writer.add_scalar(f'training/{key}', value, it)

                    if (it < 1000 and it % 200 == 0) or (it >= 1000 and it % sample_int == 0):
                        writer.add_image('sampling/sample', sample(), it)

                    # Step `it` has completed here, so a checkpoint must resume
                    # from it + 1; storing `it` would replay this step.
                    if it % save_int == 0:
                        save(step_count=it + 1)

                    if (it + 1) in save_iters:
                        save(add_postfix=True, step_count=it + 1)

                    it += 1
        except KeyboardInterrupt:
            # Ctrl-C: fall through and checkpoint the current state.
            pass

        print('saving checkpoint')
        save()
    finally:
        # Flush pending TensorBoard events even if training crashed.
        writer.close()


if __name__ == '__main__':
    # Expose train()'s keyword arguments as command-line flags.
    from fire import Fire

    Fire(train)
